{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 之前发布时，cross_entropy_loss 直接用的是 unsloth 的实现。这次我结合 unsloth 和 flash-attn 库的实现，写了一个更快的 cross_entropy，并且支持 vocab 并行，可以无缝衔接 megatron 框架。目前还不支持 scale 和 smooth 等功能，但基本能满足大多数人的训练需求。（注意：本 notebook 中 tp_size 固定为 1，只实际运行了单卡路径。）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/torch/lib/python3.10/site-packages/torch/nn/_reduction.py:51: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    }
   ],
   "source": [
    "import triton\n",
    "import triton.language as tl\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "from copy import deepcopy\n",
    "import os\n",
    "from mdy_triton.core import fast_cross_entropy_loss\n",
    "# Switch on Triton's autotuning printout (TRITON_PRINT_AUTOTUNING).\n",
    "os.environ['TRITON_PRINT_AUTOTUNING'] = '1'\n",
    "# PyTorch reference loss. reduction='none' keeps one loss value per token and is the\n",
    "# supported spelling of the deprecated reduce=False (which emitted a UserWarning).\n",
    "torch_ce = torch.nn.CrossEntropyLoss(reduction='none')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward kernel: one program per row of the logits (row r starts at LOGITS + r * row_stride).\n",
    "# For every row whose label is not -100 it writes log-sum-exp of the row's N logits to\n",
    "# LOGSUMEXP[r]. If the label also lies in [vocab_start_index, vocab_start_index + N) it writes\n",
    "# to LOSSES[r] either lse - logit[label] (SPLIT false) or only -logit[label] (SPLIT true).\n",
    "# Rows labelled -100 store nothing. M is not read inside the kernel body.\n",
    "@triton.jit\n",
    "def _cross_entropy_fwd_kernel(LOGITS, LABELS, LOSSES, LOGSUMEXP,\n",
    "                             vocab_start_index, row_stride, \n",
    "                             M, N, SPLIT, BLOCK_SIZE: tl.constexpr, \n",
    "                             ):\n",
    "    row_idx = tl.program_id(0)\n",
    "    # Widen the stride so row_idx * row_stride is computed in 64-bit.\n",
    "    row_stride = row_stride.to(tl.int64)\n",
    "    # LABELS is addressed as a flat array: one label per row at LABELS + row_idx.\n",
    "    label_idx = tl.load(LABELS + row_idx).to(tl.int32)\n",
    "    if (label_idx != -100):\n",
    "        LOGITS += row_idx * row_stride\n",
    "        base_cols = tl.arange(0, BLOCK_SIZE)\n",
    "        # Running maximum (m_i) and running sum of exp(logit - m_i) (l_i) over the chunks seen so far.\n",
    "        m_i = -float(\"inf\")\n",
    "        l_i = 0.0\n",
    "        for start_n in tl.range(0, N, BLOCK_SIZE):\n",
    "            cols = start_n + base_cols\n",
    "            mask = cols < N\n",
    "            # Lanes past column N load -inf, so they add exp(-inf) = 0 to the sum.\n",
    "            logits = tl.load(LOGITS+cols, mask=mask, other=-float('inf')).to(tl.float32)\n",
    "            m_ij = tl.max(logits)\n",
    "            new_m_i = tl.maximum(m_i, m_ij)\n",
    "            # Rescale the previous sum to the new maximum before adding this chunk's terms.\n",
    "            l_i = l_i * tl.exp(m_i - new_m_i) + tl.sum(tl.exp(logits - new_m_i))\n",
    "            m_i = new_m_i\n",
    "        lse = tl.log(l_i) + m_i\n",
    "\n",
    "        # Only touch LOSSES when the label's column is among the N columns this kernel sees.\n",
    "        if (label_idx >= vocab_start_index) and (label_idx < (vocab_start_index + N)):\n",
    "            x = -1.0 * tl.load(LOGITS+label_idx-vocab_start_index).to(tl.float32)\n",
    "            if not SPLIT:\n",
    "                loss = lse + x\n",
    "                tl.store(LOSSES+row_idx, loss)\n",
    "            else:\n",
    "                # SPLIT: store only -logit[label]; the host code adds the combined logsumexp later.\n",
    "                tl.store(LOSSES+row_idx, x)\n",
    "        tl.store(LOGSUMEXP+row_idx, lse)\n",
    "\n",
    "# @triton.autotune([triton.Config({'BLOCK_SIZE': bsz}, num_stages=ns, num_warps=nw)\n",
    "#                  for bsz in [8192*2,8192*4, 8192*8]\n",
    "#                  for ns in [1, 2,3,4]\n",
    "#                  for nw in [16, 32]\n",
    "#                  ], key=['M', 'N']\n",
    "#                  )\n",
    "# Backward kernel: one program per row.\n",
    "# For a row whose label is not -100 it writes exp(logits - lse) * dloss into DLOGITS, with 1\n",
    "# subtracted from the probability at the label's column before scaling.\n",
    "# For a row labelled -100 it zero-fills the DLOGITS row only when INPLACE is truthy; otherwise\n",
    "# the row is not written at all. M is not read inside the kernel body.\n",
    "# LABELS, LOGSUMEXP and DLOSSES are addressed as base + row_idx, i.e. as flat unit-stride arrays.\n",
    "@triton.jit\n",
    "def _cross_entropy_bwd_kernel(DLOSSES, DLOGITS,\n",
    "                            LOGITS, LABELS, LOGSUMEXP,\n",
    "                             vocab_start_index, row_stride, \n",
    "                             M, N,  INPLACE,\n",
    "                             BLOCK_SIZE: tl.constexpr,\n",
    "                             ):\n",
    "    row_idx = tl.program_id(0)\n",
    "    LABELS += row_idx\n",
    "    label_idx = tl.load(LABELS).to(tl.int32)\n",
    "    # Widen the stride so row_idx * row_stride is computed in 64-bit.\n",
    "    row_stride = row_stride.to(tl.int64)\n",
    "    if (label_idx != -100):\n",
    "        # label_idx -= vocab_start_index\n",
    "        LOGITS += row_idx * row_stride\n",
    "        DLOGITS += row_idx * row_stride\n",
    "        LOGSUMEXP += row_idx\n",
    "        DLOSSES += row_idx\n",
    "        lse = tl.load(LOGSUMEXP)\n",
    "        dloss = tl.load(DLOSSES).to(tl.float32)\n",
    "        base_cols = tl.arange(0, BLOCK_SIZE)\n",
    "        for start_n in tl.range(0, N, BLOCK_SIZE):\n",
    "            cols = start_n + base_cols\n",
    "            mask = cols < N\n",
    "            # Each chunk is loaded before it is stored, so DLOGITS may be the same buffer as LOGITS.\n",
    "            logits = tl.load(LOGITS+cols, mask=mask, other=0.).to(tl.float32)\n",
    "            probs = tl.exp(logits - lse)\n",
    "            # tmp is the global vocabulary id of this chunk's first column.\n",
    "            tmp = vocab_start_index + start_n\n",
    "            # Only the chunk whose id range contains the label needs the -1 correction.\n",
    "            if (label_idx >= tmp) and (label_idx < (tmp + BLOCK_SIZE)):\n",
    "                probs = tl.where(cols+vocab_start_index != label_idx, probs, probs-1.)\n",
    "            tl.store(DLOGITS+cols, probs * dloss, mask=mask)\n",
    "    elif INPLACE:\n",
    "        # Ignored row: overwrite the row with zeros so no stale values remain in DLOGITS.\n",
    "        DLOGITS += row_idx * row_stride\n",
    "        base_cols = tl.arange(0, BLOCK_SIZE)\n",
    "        zeros = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)\n",
    "        for start_n in tl.range(0, N, BLOCK_SIZE):\n",
    "            cols = start_n + base_cols\n",
    "            mask = cols < N\n",
    "            tl.store(DLOGITS+cols, zeros, mask=mask)\n",
    "\n",
    "class _FastCrossEntropyLoss(torch.autograd.Function):\n",
    "    \"\"\"Per-token cross entropy computed by the Triton forward/backward kernels.\n",
    "\n",
    "    Labels equal to -100 are ignored: their loss and their gradient row are zero.\n",
    "    With ``inplace=True`` the backward pass writes the gradient into the logits\n",
    "    buffer itself, so the original logits values are gone after ``backward``.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, logits, labels, inplace):\n",
    "        \"\"\"Compute one loss value per row of ``logits``.\n",
    "\n",
    "        Args:\n",
    "            logits: tensor of shape (..., N) that can be viewed as (-1, N).\n",
    "            labels: integer tensor with one entry per logits row; -100 marks an ignored row.\n",
    "            inplace: if true, ``backward`` reuses the logits storage for the gradient.\n",
    "\n",
    "        Returns:\n",
    "            float32 tensor of shape ``logits.shape[:-1]`` holding the per-row losses.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``labels`` does not have exactly one element per logits row.\n",
    "        \"\"\"\n",
    "        ctx.input_shape = logits.shape\n",
    "        # Vocab parallelism is wired in but fixed to a single rank here, so the\n",
    "        # `split` branch below does not run with these values.\n",
    "        tp_rank = 0\n",
    "        tp_size = 1\n",
    "        tp_group = None\n",
    "        N = ctx.input_shape[-1]\n",
    "        logits = logits.view(-1, N)\n",
    "        M = logits.size(0)\n",
    "        if labels.numel() != M:\n",
    "            raise ValueError(\n",
    "                f'labels has {labels.numel()} elements but logits has {M} rows'\n",
    "            )\n",
    "        # The kernels read labels as base + row_idx, so they need a flat unit-stride tensor.\n",
    "        labels = labels.contiguous().view(-1)\n",
    "        # Zero-initialised because the kernel stores nothing for ignored rows.\n",
    "        # The shape is passed as a tuple so 1-D logits (empty leading shape) also work.\n",
    "        losses = torch.zeros(ctx.input_shape[:-1], device=logits.device, dtype=torch.float32)\n",
    "        split = tp_size > 1\n",
    "        vocab_start_index = N * tp_rank\n",
    "        logsumexp = torch.zeros(M, device=logits.device, dtype=torch.float32)\n",
    "        _cross_entropy_fwd_kernel[(M,)](logits, labels, losses, logsumexp,\n",
    "                                        vocab_start_index, logits.stride(0),\n",
    "                                        M, N, split,\n",
    "                                        BLOCK_SIZE=4096, num_warps=4, num_stages=3\n",
    "                                        )\n",
    "        if split:\n",
    "            # Each rank holds -logit[label] (or 0) and its local logsumexp; combine them.\n",
    "            lse_allgather = torch.empty(tp_size, M, dtype=logsumexp.dtype, device=logsumexp.device)\n",
    "            torch.distributed.all_gather_into_tensor(lse_allgather, logsumexp, group=tp_group)\n",
    "            # Use the same process group as the all-gather above.\n",
    "            torch.distributed.all_reduce(\n",
    "                losses, op=torch.distributed.ReduceOp.SUM, group=tp_group,\n",
    "            )\n",
    "            logsumexp = torch.logsumexp(lse_allgather, dim=0)\n",
    "            # logsumexp and labels are flat (M,), so update losses through a flat view;\n",
    "            # adding them to the (..., ) shaped tensor directly fails for 3-D logits.\n",
    "            flat_losses = losses.view(-1)\n",
    "            flat_losses += logsumexp\n",
    "            flat_losses.masked_fill_(labels == -100, 0.)\n",
    "        ctx.save_for_backward(logits, labels, logsumexp)\n",
    "        ctx.inplace = inplace\n",
    "        ctx.tp_rank = tp_rank\n",
    "        return losses\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, dlosses):\n",
    "        \"\"\"Return the gradient w.r.t. ``logits``; ``labels`` and ``inplace`` get ``None``.\"\"\"\n",
    "        logits, labels, logsumexp = ctx.saved_tensors\n",
    "        # The saved logits are already the 2-D (M, N) view created in forward.\n",
    "        M, N = logits.shape\n",
    "        # The kernel reads dlosses as base + row_idx. An upstream gradient such as the one\n",
    "        # produced by `losses.sum().backward()` is an expanded stride-0 tensor, so it must be\n",
    "        # materialised first or the kernel would read past its single stored element.\n",
    "        dlosses = dlosses.contiguous().view(-1)\n",
    "        # Not in-place: zeros are required because ignored rows are never written by the kernel.\n",
    "        dlogits = logits if ctx.inplace else torch.zeros_like(logits)\n",
    "        vocab_start_index = N * ctx.tp_rank\n",
    "        _cross_entropy_bwd_kernel[(M,)](dlosses, dlogits,\n",
    "                                        logits, labels, logsumexp,\n",
    "                                        vocab_start_index, logits.stride(0),\n",
    "                                        M, N, ctx.inplace,\n",
    "                                        BLOCK_SIZE=16384, num_warps=16, num_stages=4\n",
    "                                        # BLOCK_SIZE=32768, num_warps=32, num_stages=1\n",
    "                                        )\n",
    "        return dlogits.view(ctx.input_shape), None, None\n",
    "def triton_entropy_loss(logits, labels, inplace=False):\n",
    "    \"\"\"Functional entry point: per-token cross entropy via ``_FastCrossEntropyLoss``.\n",
    "\n",
    "    ``inplace`` is forwarded unchanged; when true, the backward pass stores the\n",
    "    gradient in the logits buffer instead of allocating a new tensor.\n",
    "    \"\"\"\n",
    "    loss_fn = _FastCrossEntropyLoss.apply\n",
    "    return loss_fn(logits, labels, inplace)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "logits显存占用： 0.6103515625 G\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[54972,  -100,  -100,  ..., 67878,  -100, 40067],\n",
       "        [70641, 29120, 42895,  ...,  -100,  -100,  -100],\n",
       "        [ 6666,  -100,  -100,  ...,  -100,  -100, 68260],\n",
       "        [79595, 72069,  7431,  ...,  -100,  -100, 67588]], device='cuda:0')"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cuda.empty_cache()\n",
    "torch.manual_seed(0)  # fixed seed so a re-run regenerates the same logits and labels\n",
    "bs = 4\n",
    "seq_len = 1024\n",
    "vocab_size = 80000\n",
    "dtype=torch.bfloat16\n",
    "device = 'cuda:0'\n",
    "# Bytes per element taken from the dtype itself instead of a float32-or-else-2 guess.\n",
    "factor = torch.finfo(dtype).bits // 8\n",
    "print('logits显存占用：',(bs * seq_len * vocab_size) / (1024)**3 * factor,\"G\")\n",
    "logits1 = torch.randn(bs, seq_len, vocab_size, dtype=dtype, device=device)\n",
    "logits1.requires_grad_(True)\n",
    "logits2 = deepcopy(logits1)\n",
    "# Draw from [-vocab_size, vocab_size - 2] and turn every negative draw into the ignore\n",
    "# index -100, so roughly half of the positions are ignored. Created directly on `device`.\n",
    "labels = torch.randint(0, vocab_size*2-1, (bs, seq_len), device=device) - vocab_size\n",
    "labels.masked_fill_(labels<0, -100)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# Compare the Triton loss (in-place backward) against the PyTorch reference, feeding\n",
    "# both the same random upstream gradient.\n",
    "tol = 1e-4\n",
    "y1 = triton_entropy_loss(logits1, labels, True)\n",
    "y2 = torch_ce(logits2.view(-1, vocab_size).to(torch.float32), labels.view(-1))\n",
    "dy = torch.rand_like(y1)\n",
    "y1.backward(dy)\n",
    "y2.backward(dy.view(-1))\n",
    "losses_match = torch.allclose(y1.view(-1), y2.view(-1), rtol=tol, atol=tol)\n",
    "print(losses_match)\n",
    "grads_match = torch.allclose(logits1.grad, logits2.grad, rtol=tol, atol=tol)\n",
    "print(grads_match)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9373115301132202\n",
      "2.3171751499176025\n",
      "13.22261905670166\n"
     ]
    }
   ],
   "source": [
    "# Forward-only timings, printed in this order: my kernel, unsloth, torch.\n",
    "forward_only_runs = (\n",
    "    lambda: triton_entropy_loss(logits1, labels),\n",
    "    lambda: fast_cross_entropy_loss(logits1, labels),\n",
    "    lambda: torch_ce(logits2.view(-1, vocab_size).float(), labels.view(-1)),\n",
    ")\n",
    "for run in forward_only_runs:\n",
    "    print(triton.testing.do_bench(run))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7.344710350036621\n",
      "9.676716804504395\n",
      "41.73855209350586\n"
     ]
    }
   ],
   "source": [
    "torch.cuda.empty_cache()\n",
    "\n",
    "\n",
    "def fwd_bwd1():\n",
    "    \"\"\"Forward + backward through the Triton loss with the in-place gradient.\"\"\"\n",
    "    per_token = triton_entropy_loss(logits1, labels, inplace=True)\n",
    "    per_token.sum().backward()\n",
    "\n",
    "\n",
    "def fwd_bwd2():\n",
    "    \"\"\"Forward + backward through the imported fast_cross_entropy_loss (unsloth).\"\"\"\n",
    "    per_token = fast_cross_entropy_loss(logits1, labels)\n",
    "    per_token.sum().backward()\n",
    "\n",
    "\n",
    "def fwd_bwd3():\n",
    "    \"\"\"Forward + backward through the PyTorch reference loss on float32 logits.\"\"\"\n",
    "    flat_logits = logits1.view(-1, logits1.size(-1)).float()\n",
    "    per_token = torch_ce(flat_logits, labels.view(-1))\n",
    "    per_token.sum().backward()\n",
    "\n",
    "\n",
    "# Timings are printed in this order: my kernel, unsloth, torch.\n",
    "for step_fn in (fwd_bwd1, fwd_bwd2, fwd_bwd3):\n",
    "    print(triton.testing.do_bench(step_fn, rep=1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
